{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prototype hybrid embedding : data-parallel frequent categories and model- parallel infrequent categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "\n",
    "np.random.seed(42) # guarantees the same data each execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "num_slots = 10\n",
    "num_nodes = 2\n",
    "num_networks_per_node = 4\n",
    "embedding_vec_size = 128\n",
    "num_networks = num_networks_per_node*num_nodes\n",
    "local_batch_size = int((batch_size + num_networks - 1)/(num_networks))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_data(data):\n",
    "    # concatenate all iterations\n",
    "    samples_data = np.concatenate([deepcopy(data[i][1])\n",
    "                                   for i in range(len(data))], axis=1)\n",
    "\n",
    "    # data dimensions\n",
    "    embedding_sizes = data[0][0]\n",
    "    num_tables = samples_data.shape[0]\n",
    "    num_samples = samples_data.shape[1]\n",
    "\n",
    "    # flatten and make categories unique\n",
    "    samples = np.zeros(num_tables * num_samples, dtype=np.int32)\n",
    "    category_index_offset = 0\n",
    "    for j in range(num_tables):\n",
    "        for i in range(num_samples):\n",
    "            samples[i * num_tables + j] = (category_index_offset\n",
    "                                          + samples_data[j, i])\n",
    "        category_index_offset += embedding_sizes[j]\n",
    "\n",
    "    return samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic data\n",
    "embed_sizes = np.random.randint(1, 2*batch_size, num_slots);\n",
    "num_categories = sum(embed_sizes)\n",
    "print(\"num_categories:\", num_categories)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\",\".join(map(str, embed_sizes)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = []\n",
    "num_batches = 15\n",
    "for i in range(num_batches):\n",
    "    batch = np.zeros((num_slots, batch_size))\n",
    "    for j in range(num_slots):\n",
    "        batch[j, :] = np.random.randint(0, embed_sizes[j], batch_size)\n",
    "    data.append((embed_sizes, batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = flatten_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Gpu:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.frequent_categories = None\n",
    "        self.category_frequent_index = None\n",
    "        self.frequent_sample_indices = None\n",
    "        self.frequent_embedding_vectors = None\n",
    "        self.frequent_partial_gradients = None\n",
    "        \n",
    "        self.num_infrequent = None\n",
    "        self.infrequent_embedding_vectors = None\n",
    "        self.category_location = None\n",
    "        self.model_indices = None\n",
    "        self.model_indices_offsets = None\n",
    "        self.network_indices = None\n",
    "        self.network_indices_offsets = None\n",
    "        self.ggpu_id = None\n",
    "        self.gpu_id = None\n",
    "        self.node = None\n",
    "\n",
    "        self.a2a_fwd_message_buf = []\n",
    "\n",
    "    def init_frequent_embedding_cache(self, num_frequent, embedding_vec_size):\n",
    "        self.num_frequent = num_frequent\n",
    "        #self.frequent_embedding_vectors = np.random.random((num_frequent,embedding_vec_size)) + 2\n",
    "        #self.frequent_partial_gradients = np.random.random((num_frequent,embedding_vec_size))\n",
    "        self.frequent_embedding_vectors = np.zeros((num_frequent,embedding_vec_size), dtype = np.float32)\n",
    "        self.frequent_partial_gradients = np.zeros((num_frequent,embedding_vec_size), dtype = np.float32)\n",
    "        for i in range(num_frequent):\n",
    "            for j in range(embedding_vec_size):\n",
    "                self.frequent_embedding_vectors[i, j] = float(i*embedding_vec_size + j + 1);\n",
    "    def init_infrequent_embedding_cache(self, embedding_vec_size):\n",
    "        self.num_infrequent = np.sum(self.category_location[:,0] == self.ggpu_id)\n",
    "        #self.infrequent_embedding_vectors = np.random.random((self.num_infrequent,embedding_vec_size))\n",
    "        self.infrequent_embedding_vectors = np.zeros((self.num_infrequent,embedding_vec_size), dtype = np.float32)\n",
    "        for i in range(self.num_infrequent):\n",
    "            for j in range(embedding_vec_size):\n",
    "                self.infrequent_embedding_vectors[i, j] = -1.0*float(i*embedding_vec_size + j + 1)\n",
    "\n",
    "    def init_embedding_output(self, embedding_vec_size):\n",
    "        self.embedding_output = np.zeros((int(local_batch_size*num_slots), embedding_vec_size), dtype = np.float32)\n",
    "        \n",
    "class Node:\n",
    "\n",
    "    def __init__(self, num_gpus):\n",
    "        self.gpus = [Gpu() for i in range(num_gpus)]\n",
    "        for i in range(num_gpus):\n",
    "            self.gpus[i].gpu_id = i\n",
    "            self.gpus[i].node = self # reference to this node\n",
    "\n",
    "class Network:\n",
    "\n",
    "    def __init__(self, nodes):\n",
    "        self.nodes = nodes\n",
    "\n",
    "    def all_reduce(self):\n",
    "        pass\n",
    "\n",
    "    def all_to_all(self, samples):\n",
    "        gpus = [gpu for node in self.nodes for gpu in node.gpus]\n",
    "        for srcidx,srcgpu in enumerate(gpus):\n",
    "            for dstidx, dstgpu in enumerate(gpus):\n",
    "                model_indices_to_dst = srcgpu.model_indices[srcgpu.model_indices_offsets[dstidx]:srcgpu.model_indices_offsets[dstidx+1]]\n",
    "                samples_to_dst = samples[model_indices_to_dst]\n",
    "                embedding_vec_incides = srcgpu.category_location[samples_to_dst, 1]\n",
    "                data_send = srcgpu.infrequent_embedding_vectors[embedding_vec_incides, :]\n",
    "                dstgpu.a2a_fwd_message_buf.append(data_send)\n",
    "        for gpu in gpus:\n",
    "            gpu.a2a_fwd_message_buf = np.concatenate(gpu.a2a_fwd_message_buf)\n",
    "\n",
    "    def embedding_network_forward(self, samples):\n",
    "        gpus = [gpu for node in self.nodes for gpu in node.gpus]\n",
    "        for gpuid, gpu in enumerate(gpus):\n",
    "            for i, idx in enumerate(gpu.network_indices):\n",
    "                gpu.embedding_output[idx] = gpu.a2a_fwd_message_buf[i]\n",
    "            index_base = int(local_batch_size*num_slots*gpu.ggpu_id)\n",
    "            for i, idx in enumerate(gpu.frequent_sample_indices):\n",
    "                category = samples[idx + index_base]\n",
    "                embedding_vec_index = gpu_.category_frequent_index[category]\n",
    "                gpu.embedding_output[idx] = gpu_.frequent_embedding_vectors[embedding_vec_index]\n",
    "            num_categories = gpu.category_location.shape[0]\n",
    "            for i in range(gpu.embedding_output.shape[0]):\n",
    "                category = samples[i + index_base]\n",
    "                is_frequent = (gpu.category_location[category][0] == num_categories)\n",
    "                #print(\"i = {}, is_frequent = {}, category = {}, location = {}, vec[0] = {}\".format(i, is_frequent, category, gpu.category_location[category][0], gpu.embedding_output[i][0]))\n",
    "                assert is_frequent == (gpu.embedding_output[i][0] > 0), \"wrong location of embedding vector\"\n",
    "            #pdb.set_trace()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "nodes = [Node(num_networks_per_node) for i in range(num_nodes)]\n",
    "gpus = [gpu for node in nodes for gpu in node.gpus]\n",
    "num_gpus = len(gpus)\n",
    "for i in range(num_nodes):\n",
    "    nodes[i].node_id = i\n",
    "network = Network(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_frequent = 128\n",
    "uniques, counts = np.unique(samples, return_counts=True)\n",
    "sorted_uniques = sorted(zip(counts, uniques), key=lambda x: -x[0])\n",
    "# print(sorted_uniques)\n",
    "frequent = set([unique_ for _, unique_ in sorted_uniques[:128]])\n",
    "assert(len(frequent) == num_frequent)\n",
    "print(num_frequent, \"/\", num_categories)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# category_frequent_index\n",
    "frequent_mask = np.zeros(num_categories, dtype=bool)\n",
    "for c in frequent:\n",
    "    frequent_mask[c] = True\n",
    "category_frequent_index = num_categories * np.ones(num_categories, dtype=np.int32)\n",
    "category_frequent_index[frequent_mask] = range(num_frequent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# category_location\n",
    "num_infrequent = num_categories - num_frequent\n",
    "category_location = num_categories * np.ones((num_categories,2), dtype=np.int32)\n",
    "infrequent_index = np.zeros(num_categories)\n",
    "infrequent_mask = np.logical_not(frequent_mask)\n",
    "infrequent_index[infrequent_mask] = range(num_infrequent)\n",
    "for c in range(num_categories):\n",
    "    if c not in frequent:\n",
    "        index = infrequent_index[c]\n",
    "        category_location[c,:] = [index % num_gpus, index // num_gpus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, gpu in enumerate(gpus):\n",
    "    gpu.category_location = category_location\n",
    "    gpu.ggpu_id = idx;\n",
    "    gpu.category_frequent_index = category_frequent_index\n",
    "    gpu.init_frequent_embedding_cache(num_frequent, embedding_vec_size)\n",
    "    gpu.init_infrequent_embedding_cache(embedding_vec_size)\n",
    "    gpu.init_embedding_output(embedding_vec_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_display = 20\n",
    "class color:\n",
    "   PURPLE = '\\033[95m'\n",
    "   CYAN = '\\033[96m'\n",
    "   DARKCYAN = '\\033[36m'\n",
    "   BLUE = '\\033[94m'\n",
    "   GREEN = '\\033[92m'\n",
    "   YELLOW = '\\033[93m'\n",
    "   RED = '\\033[91m'\n",
    "   BOLD = '\\033[1m'\n",
    "   UNDERLINE = '\\033[4m'\n",
    "   END = '\\033[0m'\n",
    "print(f'{color.BOLD}{color.RED}category          |-> frequent category cache index{color.END}')\n",
    "for category in range(n_display):\n",
    "    frequent_category_cache_index = category_frequent_index[category]\n",
    "    if frequent_category_cache_index < num_categories:\n",
    "        print(f'category {color.BOLD}{category:3d}{color.END}      |->  cache index {color.BOLD}{color.BLUE}{frequent_category_cache_index:6,d}{color.END}')\n",
    "    else:\n",
    "        print(f'category {color.BOLD}{category:3d}{color.END}      |->  cache index    {color.BOLD}{color.RED}END{color.END}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_display = 20\n",
    "print(f'{color.BOLD}{color.RED}category          |->  category location {color.END}')\n",
    "for category in range(n_display):\n",
    "    location = category_location[category,:]\n",
    "    if location[0] < num_categories:\n",
    "        print(f'category {color.BOLD}{category:3d}{color.END}      |->  category location {color.BOLD}{color.GREEN}{location}{color.END}')\n",
    "    else:\n",
    "        print(f'category {color.BOLD}{category:3d}{color.END}      |->  category location   {color.BOLD}{color.RED}END{color.END}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert np.sum(category_location[:,0] != num_categories) == num_infrequent\n",
    "assert np.sum(category_frequent_index != num_categories) == num_frequent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Index calculations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bisect import bisect_left\n",
    "\n",
    "def get_node_gpu(node_id, gpu_id):\n",
    "    # not efficient, but that's not the point here! :P\n",
    "    node = None\n",
    "    gpu = None\n",
    "    for node_ in nodes:\n",
    "        if node_.node_id == node_id:\n",
    "            node = node_\n",
    "            break\n",
    "    for gpu_ in node.gpus:\n",
    "        if gpu_.gpu_id == gpu_id:\n",
    "            gpu = gpu_\n",
    "            break\n",
    "    return node, gpu\n",
    "\n",
    "def get_network_id(node_id, gpu_id):\n",
    "    for i in range(len(gpus)):\n",
    "        if gpus[i].node.node_id == node_id and gpus[i].gpu_id == gpu_id:\n",
    "            return i\n",
    "    raise KeyError(f\"Not found: node {node_id}, GPU {gpu_id}\")\n",
    "\n",
    "def cub_DeviceSelect(gpu, samples, network_id):\n",
    "    location = gpu.category_location[samples,:]\n",
    "    samples_mask = (location[:,0] == network_id)\n",
    "    samples_filter = np.r_[:samples.size][samples_mask]\n",
    "    return samples_filter\n",
    "\n",
    "# model indices: forward-send, backward-receive\n",
    "def calculate_model_indices(samples, node_id, gpu_id):\n",
    "    _, gpu = get_node_gpu(node_id, gpu_id)\n",
    "    network_id = get_network_id(node_id, gpu_id)\n",
    "    section_size = samples.size // num_gpus\n",
    "\n",
    "    sample_model_indices = cub_DeviceSelect(gpu, samples, network_id)\n",
    "    network_offset_model_indices = np.zeros(num_gpus + 1, dtype=np.int32)\n",
    "    for i in range(num_gpus):\n",
    "        network_offset_model_indices[i] = bisect_left(sample_model_indices, i * section_size)\n",
    "    network_offset_model_indices[-1] = sample_model_indices.size\n",
    "\n",
    "    return sample_model_indices, network_offset_model_indices\n",
    "\n",
    "# network indices: forward-receive, backward-send\n",
    "def calculate_network_indices(samples, node_id, gpu_id):\n",
    "    _, gpu = get_node_gpu(node_id, gpu_id)\n",
    "\n",
    "    section_size = samples.size // num_gpus\n",
    "    network_id = get_network_id(node_id, gpu_id)\n",
    "    start_idx = network_id * section_size\n",
    "    end_idx = min((network_id + 1) * section_size, samples.size)\n",
    "    sub_batch = samples[start_idx:end_idx]\n",
    "\n",
    "    location = gpu.category_location[sub_batch,:]\n",
    "    samples_mask = location[:,0] < num_categories\n",
    "    infrequent_indices = deepcopy(np.r_[:sub_batch.size][samples_mask])\n",
    "    network_indices = deepcopy(location[:, 0][samples_mask])\n",
    "    sorted_indices = np.array(sorted(zip(network_indices, infrequent_indices),\n",
    "                                     key=lambda x: x[0]))\n",
    "\n",
    "    sample_network_offsets = np.zeros(num_gpus + 1, dtype=np.int32)\n",
    "    if len(network_indices):\n",
    "        sample_network_indices = sorted_indices[:,1]\n",
    "        for i in range(num_gpus):\n",
    "            sample_network_offsets[i] = bisect_left(sorted_indices[:,0], i)\n",
    "    else:\n",
    "        sample_network_indices = np.zeros(0)\n",
    "    sample_network_offsets[-1] = len(network_indices)\n",
    "    \n",
    "    return sample_network_indices, sample_network_offsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iteration = 0\n",
    "batch = flatten_data([data[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_indices = {}\n",
    "model_indices_offsets = {}\n",
    "for node_ in nodes:\n",
    "    for gpu_ in node_.gpus:\n",
    "        node_id = node_.node_id\n",
    "        gpu_id = gpu_.gpu_id\n",
    "        idx, off = calculate_model_indices(batch, node_id, gpu_id)\n",
    "        model_indices[(node_id, gpu_id)] = idx\n",
    "        model_indices_offsets[(node_id, gpu_id)] = off\n",
    "        gpu_.model_indices = idx\n",
    "        gpu_.model_indices_offsets = off\n",
    "\n",
    "# print(model_indices)\n",
    "# print(model_indices_offsets)\n",
    "\n",
    "network_indices = {}\n",
    "network_indices_offsets = {}\n",
    "for node_ in nodes:\n",
    "    for gpu_ in node_.gpus:\n",
    "        node_id = node_.node_id\n",
    "        gpu_id = gpu_.gpu_id\n",
    "        idx, off = calculate_network_indices(batch, node_id, gpu_id)\n",
    "        network_indices[(node_id, gpu_id)] = idx\n",
    "        network_indices_offsets[(node_id, gpu_id)] = off\n",
    "        gpu_.network_indices = idx\n",
    "        gpu_.network_indices_offsets = off\n",
    "\n",
    "# print(network_indices)\n",
    "# print(network_indices_offsets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# frequent sample indices\n",
    "def calculate_frequent_sample_indices(samples, node_id, gpu_id):\n",
    "    _, gpu = get_node_gpu(node_id, gpu_id)\n",
    "    \n",
    "    section_size = samples.size // num_gpus\n",
    "    network_id = get_network_id(node_id, gpu_id)\n",
    "    start_idx = network_id * section_size\n",
    "    end_idx = min((network_id + 1) * section_size, samples.size)\n",
    "    sub_batch = samples[start_idx:end_idx]\n",
    "\n",
    "    freq_indices = gpu.category_frequent_index[sub_batch]\n",
    "    samples_mask = freq_indices < num_categories\n",
    "    frequent_sample_indices = deepcopy(np.r_[:sub_batch.size][samples_mask])\n",
    "    \n",
    "    return frequent_sample_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "frequent_sample_indices = {}\n",
    "for node_ in nodes:\n",
    "    for gpu_ in node_.gpus:\n",
    "        node_id = node_.node_id\n",
    "        gpu_id = gpu_.gpu_id\n",
    "        frequent_sample_indices[(node_id, gpu_id)] = \\\n",
    "            calculate_frequent_sample_indices(batch, node_id, gpu_id)\n",
    "        gpu_.frequent_sample_indices = frequent_sample_indices[(node_id, gpu_id)]\n",
    "\n",
    "# print(frequent_sample_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "network.all_to_all(batch)\n",
    "network.embedding_network_forward(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_model_cache_indices(samples, node_id, gpu_id):\n",
    "    _, gpu = get_node_gpu(node_id, gpu_id)\n",
    "    \n",
    "    section_size = samples.size // num_gpus\n",
    "    model_id = get_network_id(node_id, gpu_id)\n",
    "    num_frequent_per_model = num_frequent // num_gpus\n",
    "    \n",
    "    # Compute the mask\n",
    "    network_frequent_mask = np.zeros(num_gpus * num_frequent_per_model, dtype=bool)\n",
    "    for i in range(num_gpus):\n",
    "        freq_index = gpu.category_frequent_index[\n",
    "            samples[i * section_size:(i+1)*section_size]]\n",
    "        for idx in freq_index:\n",
    "            if idx < num_frequent and idx // num_frequent_per_model == model_id:\n",
    "                network_frequent_mask[i * num_frequent_per_model\n",
    "                                      + idx % num_frequent_per_model] = True\n",
    "  \n",
    "    # Select categories according to the mask\n",
    "    model_cache_indices = np.r_[:num_gpus * num_frequent_per_model][network_frequent_mask]\n",
    "    \n",
    "    # Compute offsets\n",
    "    model_cache_indices_offsets = np.zeros(num_gpus + 1, dtype=np.int32)\n",
    "    for i in range(num_gpus):\n",
    "        model_cache_indices_offsets[i] = bisect_left(model_cache_indices, i * num_frequent_per_model)\n",
    "    model_cache_indices_offsets[num_gpus] = model_cache_indices.size\n",
    "    \n",
    "    # Convert to buffer indices\n",
    "    model_cache_indices = (model_cache_indices % num_frequent_per_model\n",
    "                           + model_id * num_frequent_per_model)\n",
    "    \n",
    "    return model_cache_indices, model_cache_indices_offsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_cache_indices = {}\n",
    "model_cache_indices_offsets = {}\n",
    "for node_ in nodes:\n",
    "    for gpu_ in node_.gpus:\n",
    "        node_id = node_.node_id\n",
    "        gpu_id = gpu_.gpu_id\n",
    "        idx, off = calculate_model_cache_indices(batch, node_id, gpu_id)\n",
    "        model_cache_indices[(node_id, gpu_id)] = idx\n",
    "        model_cache_indices_offsets[(node_id, gpu_id)] = off\n",
    "\n",
    "# print(model_cache_indices)\n",
    "# print(model_cache_indices_offsets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_network_cache_indices(samples, node_id, gpu_id):\n",
    "    _, gpu = get_node_gpu(node_id, gpu_id)\n",
    "    model_id = get_network_id(node_id, gpu_id)\n",
    "    network_mask = np.zeros(num_frequent, dtype=bool)\n",
    "    section_size = samples.size // num_gpus\n",
    "    sample_mlp_batch = samples[model_id * section_size:min(samples.size,(model_id + 1)*section_size)]\n",
    "    freq_index = gpu.category_frequent_index[sample_mlp_batch]\n",
    "    for index in freq_index:\n",
    "        if index < num_frequent:\n",
    "            network_mask[index] = True\n",
    "    network_cache_indices = np.r_[:num_frequent][network_mask]\n",
    "    \n",
    "    # Compute offsets\n",
    "    num_frequent_per_model = num_frequent // num_gpus\n",
    "    network_cache_indices_offsets = np.zeros(num_gpus + 1, dtype=np.int32)\n",
    "    for i in range(num_gpus):\n",
    "        network_cache_indices_offsets[i] = bisect_left(network_cache_indices, i * num_frequent_per_model)\n",
    "    network_cache_indices_offsets[num_gpus] = network_cache_indices.size\n",
    "\n",
    "    return network_cache_indices, network_cache_indices_offsets, network_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "network_cache_indices = {}\n",
    "network_cache_indices_offsets = {}\n",
    "network_cache_masks = {}\n",
    "for node_ in nodes:\n",
    "    for gpu_ in node_.gpus:\n",
    "        node_id = node_.node_id\n",
    "        gpu_id = gpu_.gpu_id\n",
    "        idx, off, mask = calculate_network_cache_indices(batch, node_id, gpu_id)\n",
    "        network_cache_indices[(node_id, gpu_id)] = idx\n",
    "        network_cache_indices_offsets[(node_id, gpu_id)] = off\n",
    "        network_cache_masks[(node_id, gpu_id)] = mask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate arrays"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(list(np.reshape(category_location, 2*num_categories)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(list(category_frequent_index))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(list(batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(model_indices.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, model_indices[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(model_indices_offsets.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, model_indices_offsets[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(network_indices.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, network_indices[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(network_indices_offsets.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, network_indices_offsets[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(frequent_sample_indices.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, frequent_sample_indices[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(model_indices.keys()):\n",
    "     _, gpu = get_node_gpu(key[0], key[1])\n",
    "     print(\"{%s},\" % ','.join(map(str, gpu.embedding_output.reshape(gpu.embedding_output.size))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(model_cache_indices.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, model_cache_indices[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(model_cache_indices_offsets.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, model_cache_indices_offsets[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(model_indices.keys()):\n",
    "     _, gpu = get_node_gpu(key[0], key[1])\n",
    "     print(\"{%s},\" % ','.join(map(str, gpu.a2a_fwd_message_buf.reshape(gpu.a2a_fwd_message_buf.size))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(network_cache_indices.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, network_cache_indices[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(network_cache_indices_offsets.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(str, network_cache_indices_offsets[key])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in sorted(network_cache_masks.keys()):\n",
    "    print(\"{%s},\" % ','.join(map(lambda x: str(int(x)), network_cache_masks[key])))"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
